/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 * contributors may be used to endorse or promote products derived from this
 * software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *//*
 */

/** @file   cutlass_matmul.h
 *  @author Thomas Müller, NVIDIA
 *  @brief  Matrix multiplication wrappers that call into CUTLASS (plus some
 * custom modifications). The parameters are optimized to give optimal
 * performance in a variety of situations. Parts of this file were adapted by
 * starting from the CUTLASS sample code (see its BSD 3-clause license).
 */

#pragma once

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/functional.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_splitk_parallel.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <iostream>
#include <map>
#include <type_traits>

#include "utils.h"

//#define TCNN_VERBOSE_MEMORY_ALLOCS

// Evaluates `status` exactly once as a cutlass::Status. If the result is not
// cutlass::Status::kSuccess, prints the CUTLASS status string and the source
// location to std::cerr and terminates the process with EXIT_FAILURE.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement (safe in an unbraced if/else); invoke it with a trailing
// semicolon. The argument is parenthesized, and the local uses a name that is
// unlikely to collide with an identifier appearing in `status` (a local called
// `error` would self-initialize for a call such as CUTLASS_CHECK(error)).
#define CUTLASS_CHECK(status)                                             \
  do {                                                                    \
    const cutlass::Status cutlass_check_status_ = (status);               \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {             \
      std::cerr << "Got cutlass error: "                                  \
                << cutlassGetStatusString(cutlass_check_status_)          \
                << " at: " << __FILE__ << ":" << __LINE__ << std::endl;   \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  } while (0)

// Evaluates `status` exactly once as a cudaError_t. If the result is not
// cudaSuccess, prints the CUDA error string and the source location to
// std::cerr and terminates the process with EXIT_FAILURE.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement (safe in an unbraced if/else); invoke it with a trailing
// semicolon. The argument is parenthesized, and the local uses a name that is
// unlikely to collide with an identifier appearing in `status` (a local called
// `error` would self-initialize for a call such as CUDA_CHECK(error)).
#define CUDA_CHECK(status)                                                \
  do {                                                                    \
    const cudaError_t cuda_check_status_ = (status);                      \
    if (cuda_check_status_ != cudaSuccess) {                              \
      std::cerr << "Got bad cuda status: "                                \
                << cudaGetErrorString(cuda_check_status_)                 \
                << " at line: " << __LINE__ << " (" << __FILE__ << ")"    \
                << std::endl;                                             \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  } while (0)

using SmArch = std::conditional_t<
    MIN_GPU_ARCH >= 80,
    std::conditional_t<std::is_same<network_precision_t, float>::value,
                       cutlass::arch::Sm75, cutlass::arch::Sm80>,
    std::conditional_t<MIN_GPU_ARCH >= 75, cutlass::arch::Sm75,
                       cutlass::arch::Sm70>>;

using TypeAccumulator =
    std::conditional_t<std::is_same<network_precision_t, float>::value, float,
                       cutlass::half_t>;
using TypeCompute =
    std::conditional_t<std::is_same<network_precision_t, float>::value, float,
                       cutlass::half_t>;

// CUTLASS operator class for element type T: SIMT for float, tensor-op for
// every other type.
template <typename T>
using MMAOp =
    std::conditional_t<std::is_same<T, float>::value,
                       cutlass::arch::OpClassSimt,
                       cutlass::arch::OpClassTensorOp>;

// Instruction-level GEMM shape for element type T:
//   - non-tensor-op (SIMT) operator class: 1x1x1,
//   - tensor-op with SmArch equal to Sm80 or Sm75: 16x8x8,
//   - tensor-op with any other SmArch: 8x8x4.
template <typename T>
using ShapeMMAOp = std::conditional_t<
    std::is_same<MMAOp<T>, cutlass::arch::OpClassTensorOp>::value,
    std::conditional_t<std::is_same<SmArch, cutlass::arch::Sm80>::value ||
                           std::is_same<SmArch, cutlass::arch::Sm75>::value,
                       cutlass::gemm::GemmShape<16, 8, 8>,
                       cutlass::gemm::GemmShape<8, 8, 4>>,
    cutlass::gemm::GemmShape<1, 1, 1>>;

// Bundles the two tile shapes a GEMM instantiation needs: `k_thread_block` is
// the threadblock tile and `k_warp` is the warp tile. Both are simply the
// template arguments re-exported as member types.
template <typename thread_block, typename warp>
struct LayerConfig {
  typedef thread_block k_thread_block;
  typedef warp k_warp;
};

// Tile configuration selected by the operator class of network_precision_t:
//   - SIMT:      threadblock 128x128x8, warp 32x64x8
//   - otherwise: threadblock 64x64x32,  warp 32x32x32
using FullLayerK = std::conditional_t<
    std::is_same<MMAOp<network_precision_t>, cutlass::arch::OpClassSimt>::value,
    LayerConfig<cutlass::gemm::GemmShape<128, 128, 8>,
                cutlass::gemm::GemmShape<32, 64, 8>>,
    LayerConfig<cutlass::gemm::GemmShape<64, 64, 32>,
                cutlass::gemm::GemmShape<32, 32, 32>>>;

// Same configuration under a second name.
using LastLayerK = FullLayerK;

// Tile configuration selected by the operator class of network_precision_t:
//   - SIMT:      threadblock 128x128x8,  warp 32x64x8
//   - otherwise: threadblock 128x128x32, warp 64x64x32
using FullLayer = std::conditional_t<
    std::is_same<MMAOp<network_precision_t>, cutlass::arch::OpClassSimt>::value,
    LayerConfig<cutlass::gemm::GemmShape<128, 128, 8>,
                cutlass::gemm::GemmShape<32, 64, 8>>,
    LayerConfig<cutlass::gemm::GemmShape<128, 128, 32>,
                cutlass::gemm::GemmShape<64, 64, 32>>>;

// Same configuration under a second name.
using LastLayer = FullLayer;

// Threadblock swizzle functor handed to the CUTLASS GEMM template in OurGemm;
// it describes how threadblocks are scheduled on the GPU. CUTLASS's identity
// swizzle is used with its default template argument.
using SwizzleThreadBlock =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// This code section describes the epilogue part of the kernel

// Thin aggregate around a fragment type V so it can be passed to the
// warp_activation helpers: `x` holds the fragment and `num_elements` re-exports
// V::kElements as an unsigned 32-bit constant.
template <typename V>
struct CutlassFragmentWrapper {
  V x;
  static constexpr uint32_t num_elements = V::kElements;
};

/// Per-thread epilogue functor applied to GEMM accumulator fragments.
///
/// Each invocation converts the accumulator fragment to ElementCompute,
/// optionally adds the source fragment (only when Params::sum_source is set),
/// passes the result through `warp_activation` with the configured Activation,
/// and converts the activated values to ElementOutput.
template <
    typename ElementOutput_,  ///< Element type of the loaded/stored tensors
    int Count,                ///< Number of elements handled per invocation
    typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
    typename ElementCompute_ =
        ElementOutput_,  ///< Data type the activation is evaluated in
    cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest>
class ActivationEpilogue {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = cutlass::Array<ElementOutput, kCount>;
  using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
  using ComputeFragment = cutlass::Array<ElementCompute, kCount>;

  static cutlass::FloatRoundStyle const kRound = Round;

  /// Parameters of the functor: the activation to apply and whether the source
  /// fragment is added to the accumulator before the activation.
  struct Params {
    Activation activation;
    bool sum_source;
  };

 public:
  /// Copies the activation and the sum_source flag out of `params`.
  CUTLASS_HOST_DEVICE
  ActivationEpilogue(Params const& params)
      : m_activation{params.activation}, m_sum_source{params.sum_source} {}

  /// The source fragment is only needed when it is going to be summed.
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return m_sum_source; }

  /// Functionally required for serial reduction in the epilogue; intentionally
  /// does nothing here.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) {}

  /// Source-free form: activation(accumulator), converted to ElementOutput.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator) const {
    return activate_and_store(to_compute(accumulator));
  }

  /// Form with a source fragment: the source is added to the accumulator only
  /// when sum_source was set; the activation is applied afterwards.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentOutput const& source) const {
    ComputeFragment pre_activation = to_compute(accumulator);
    if (m_sum_source) {
      cutlass::NumericArrayConverter<ElementCompute, ElementOutput, kCount,
                                     Round>
          source_to_compute;
      cutlass::plus<ComputeFragment> add;
      pre_activation = add(pre_activation, source_to_compute(source));
    }
    return activate_and_store(pre_activation);
  }

 private:
  /// Converts an accumulator fragment to the compute element type.
  CUTLASS_HOST_DEVICE
  ComputeFragment to_compute(FragmentAccumulator const& accumulator) const {
    cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, kCount,
                                   Round>
        accumulator_to_compute;
    return accumulator_to_compute(accumulator);
  }

  /// Applies warp_activation to `pre_activation` and converts the result to the
  /// output element type.
  CUTLASS_HOST_DEVICE
  FragmentOutput activate_and_store(
      ComputeFragment const& pre_activation) const {
    auto wrapped = CutlassFragmentWrapper<ComputeFragment>{pre_activation};
    wrapped = warp_activation<ElementCompute>(m_activation, wrapped);

    cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        compute_to_output;
    return compute_to_output(wrapped.x);
  }

  Activation m_activation;
  bool m_sum_source;
};

/// Per-thread epilogue functor that combines the GEMM accumulator fragment with
/// the source fragment through `warp_activation_backward`.
///
/// The source fragment is always requested (is_source_needed() returns true).
/// In the two-argument call both fragments are converted to ElementCompute,
/// handed to `warp_activation_backward` together with the configured
/// Activation, and the result is converted to ElementOutput. The one-argument
/// call performs only the element type conversions and applies no activation.
template <
    typename ElementOutput_,  ///< Element type of the loaded/stored tensors
    int Count,                ///< Number of elements handled per invocation
    typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
    typename ElementCompute_ =
        ElementOutput_,  ///< Data type the activation is evaluated in
    cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest>
class ActivationTransferEpilogue {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = cutlass::Array<ElementOutput, kCount>;
  using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
  using ComputeFragment = cutlass::Array<ElementCompute, kCount>;

  static cutlass::FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure: only the activation.
  struct Params {
    Activation activation;
  };

 public:
  /// Copies the activation out of `params`.
  CUTLASS_HOST_DEVICE
  ActivationTransferEpilogue(Params const& params)
      : m_activation{params.activation} {}

  /// The source fragment is always needed by this functor.
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue; intentionally
  /// does nothing here.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) {}

  /// Conversion-only form: accumulator -> ElementCompute -> ElementOutput.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator) const {
    AccumulatorToCompute accumulator_to_compute;
    ComputeToOutput compute_to_output;

    ComputeFragment in_compute_type = accumulator_to_compute(accumulator);
    return compute_to_output(in_compute_type);
  }

  /// Full form: warp_activation_backward(activation, accumulator, source),
  /// evaluated in ElementCompute and stored as ElementOutput.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const& accumulator,
                            FragmentOutput const& source) const {
    SourceToCompute source_to_compute;
    AccumulatorToCompute accumulator_to_compute;

    auto wrapped_source =
        CutlassFragmentWrapper<ComputeFragment>{source_to_compute(source)};
    auto wrapped_result = CutlassFragmentWrapper<ComputeFragment>{
        accumulator_to_compute(accumulator)};

    wrapped_result = warp_activation_backward<ElementCompute>(
        m_activation, wrapped_result, wrapped_source);

    ComputeToOutput compute_to_output;
    return compute_to_output(wrapped_result.x);
  }

 private:
  // Element-type converters shared by both call operators.
  using SourceToCompute =
      cutlass::NumericArrayConverter<ElementCompute, ElementOutput, kCount,
                                     Round>;
  using AccumulatorToCompute =
      cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, kCount,
                                     Round>;
  using ComputeToOutput =
      cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount,
                                     Round>;

  Activation m_activation;
};

// Number of elements of type T the epilogue handles per operation: a single
// element for the SIMT operator class, otherwise as many elements as fit in
// 128 bits.
template <typename T>
static constexpr int n_vectorized_elements =
    std::is_same<MMAOp<T>, cutlass::arch::OpClassSimt>::value
        ? 1
        : (128 / cutlass::sizeof_bits<T>::value);

// Epilogue used by the split-K GEMM: CUTLASS's LinearCombination over
// n_vectorized_elements<T> elements of type T, accumulating in TypeAccumulator
// and computing in TypeCompute.
template <typename T>
using SumOp =
    cutlass::epilogue::thread::LinearCombination<T, n_vectorized_elements<T>,
                                                 TypeAccumulator, TypeCompute>;

// ActivationEpilogue with a fixed element count of 4 (instead of
// n_vectorized_elements<T>).
template <typename T>
using IntermediateActivationOp =
    ActivationEpilogue<T, 4, TypeAccumulator, TypeCompute>;

// ActivationTransferEpilogue with a fixed element count of 4 (instead of
// n_vectorized_elements<T>).
template <typename T>
using IntermediateActivationTransferOp =
    ActivationTransferEpilogue<T, 4, TypeAccumulator, TypeCompute>;

// ActivationEpilogue over n_vectorized_elements<T> elements; the epilogue
// fc_multiply uses when `transfer` is false.
template <typename T>
using ActivationOp = ActivationEpilogue<T, n_vectorized_elements<T>,
                                        TypeAccumulator, TypeCompute>;

// ActivationTransferEpilogue over n_vectorized_elements<T> elements; the
// epilogue fc_multiply uses when `transfer` is true.
template <typename T>
using ActivationTransferOp =
    ActivationTransferEpilogue<T, n_vectorized_elements<T>, TypeAccumulator,
                               TypeCompute>;

// CUTLASS device-level GEMM specialised for this file.
//
// The caller supplies the epilogue functor (EPILOGUE), a LayerConfig providing
// the threadblock and warp tile shapes, and the element type and layout of the
// A, B and output matrices. Everything else is fixed here: TypeAccumulator as
// accumulator type, the operator class and instruction shape derived from
// TypeA (MMAOp / ShapeMMAOp), SmArch as architecture tag, SwizzleThreadBlock
// as threadblock swizzle, and a trailing 2, which sits in the position of
// cutlass::gemm::device::Gemm's pipeline stage count.
template <typename EPILOGUE, typename LayerConfig, typename TypeA,
          typename LayoutA, typename TypeB, typename LayoutB,
          typename TypeOutput, typename LayoutOutput>
using OurGemm =
    cutlass::gemm::device::Gemm<TypeA, LayoutA, TypeB, LayoutB, TypeOutput,
                                LayoutOutput, TypeAccumulator, MMAOp<TypeA>,
                                SmArch, typename LayerConfig::k_thread_block,
                                typename LayerConfig::k_warp, ShapeMMAOp<TypeA>,
                                EPILOGUE, SwizzleThreadBlock, 2>;

// CUTLASS parallel split-K GEMM specialised for this file.
//
// Takes the same caller-supplied parameters as OurGemm (epilogue, LayerConfig
// tile shapes, element types and layouts) and fixes TypeAccumulator, the
// operator class and instruction shape derived from TypeA, and SmArch. Unlike
// OurGemm, no swizzle or stage count is passed, so CUTLASS's defaults for the
// remaining template parameters apply.
template <typename EPILOGUE, typename LayerConfig, typename TypeA,
          typename LayoutA, typename TypeB, typename LayoutB,
          typename TypeOutput, typename LayoutOutput>
using SplitKGemm = cutlass::gemm::device::GemmSplitKParallel<
    TypeA, LayoutA, TypeB, LayoutB, TypeOutput, LayoutOutput, TypeAccumulator,
    MMAOp<TypeA>, SmArch, typename LayerConfig::k_thread_block,
    typename LayerConfig::k_warp, ShapeMMAOp<TypeA>, EPILOGUE>;

/// Returns the table that maps each CUDA stream to the scratch buffer handed
/// out by cutlass_get_workspace(). The table is a function-local static, so it
/// is created on first use and the same object is returned on every call.
inline std::map<cudaStream_t, GPUMemory<uint8_t>>& cutlass_workspaces() {
  using WorkspaceTable = std::map<cudaStream_t, GPUMemory<uint8_t>>;
  static WorkspaceTable per_stream_buffers;
  return per_stream_buffers;
}

inline uint8_t* cutlass_get_workspace(size_t size, cudaStream_t stream) {
  GPUMemory<uint8_t>& workspace = cutlass_workspaces()[stream];
  if (size > workspace.size()) {
    size *= 2;
#ifdef TCNN_VERBOSE_MEMORY_ALLOCS
    std::cout << "CUTLASS GEMM: Allocating temporary workspace of "
              << bytes_to_string(size) << "." << std::endl;
#endif

    // Allocate twice the requested size to make sure we're not constantly
    // allocating small increments.
    workspace.resize(size);
  }
  return workspace.data();
}

/// Removes the scratch buffer registered for `stream` from the workspace
/// table. Does nothing when no buffer was ever created for that stream.
inline void cutlass_free_workspace(cudaStream_t stream) {
  auto& table = cutlass_workspaces();
  const auto entry = table.find(stream);
  if (entry == table.end()) {
    return;
  }

#ifdef TCNN_VERBOSE_MEMORY_ALLOCS
  std::cout << "CUTLASS GEMM: Freeing temporary workspace of "
            << bytes_to_string(entry->second.size()) << "."
            << std::endl;
#endif
  table.erase(entry);
}

/// Initializes and launches the CUTLASS GEMM type `Gemm` on `stream`.
///
/// The workspace size CUTLASS reports for `args` is obtained from the
/// per-stream scratch buffer, the operator is initialized with `args`, that
/// workspace and the stream, and then launched. Any non-success status from
/// CUTLASS goes through CUTLASS_CHECK, which terminates the process.
///
/// @tparam Gemm   CUTLASS device-level GEMM type (e.g. an OurGemm alias).
/// @param stream  CUDA stream used for initialization and launch.
/// @param args    Fully populated Gemm::Arguments.
template <class Gemm>
void fc_multiply_impl(cudaStream_t stream,
                      const typename Gemm::Arguments& args) {
  // Ask CUTLASS how much scratch memory this problem needs and fetch a buffer
  // of at least that size for the stream.
  const size_t workspace_bytes = Gemm::get_workspace_size(args);
  uint8_t* const workspace = cutlass_get_workspace(workspace_bytes, stream);

  Gemm gemm_op;

  // Bind arguments and workspace to the operator, then launch it.
  CUTLASS_CHECK(gemm_op.initialize(args, workspace, stream));
  CUTLASS_CHECK(gemm_op(stream));
}

/// Initializes and launches the CUTLASS split-K GEMM type `Gemm` on `stream`.
///
/// The workspace size CUTLASS reports for `args` is obtained from the
/// per-stream scratch buffer, the operator is initialized with `args` and that
/// workspace, and then launched on `stream`. Unlike fc_multiply_impl, the
/// stream is only passed to the launch, not to initialize(). Any non-success
/// status from CUTLASS goes through CUTLASS_CHECK, which terminates the
/// process.
///
/// @tparam Gemm   CUTLASS split-K GEMM type (e.g. a SplitKGemm alias).
/// @param stream  CUDA stream the operator is launched on.
/// @param args    Fully populated Gemm::Arguments.
template <class Gemm>
void fc_multiply_split_k_impl(cudaStream_t stream,
                              const typename Gemm::Arguments& args) {
  // Ask CUTLASS how much scratch memory this problem needs and fetch a buffer
  // of at least that size for the stream.
  const size_t workspace_bytes = Gemm::get_workspace_size(args);
  uint8_t* const workspace = cutlass_get_workspace(workspace_bytes, stream);

  Gemm gemm_op;

  // Bind arguments and workspace to the operator, then launch it.
  CUTLASS_CHECK(gemm_op.initialize(args, workspace));
  CUTLASS_CHECK(gemm_op(stream));
}

//////////////////////////////////////////////////////////////////////////////////
////////////////////////////        modified ///////////////////////////////
//////////////////////////////////////////////////////////////////////////////////

/// Runs one half-precision GEMM with a fused activation epilogue on `stream`.
///
/// Shapes: A is [M, K], B is [K, N], C and D are [M, N].
///
///   - transfer == false: the ActivationEpilogue is used, i.e.
///       D = warp_activation(act, A @ B + C)   when sum_source is true
///       D = warp_activation(act, A @ B)       when sum_source is false
///   - transfer == true: the ActivationTransferEpilogue is used, i.e.
///       D = warp_activation_backward(act, A @ B, C)
///     C is always read in this mode and `sum_source` is ignored.
///
/// Leading dimensions are derived from the shapes (densely packed matrices):
/// a row-major matrix uses its column count, a column-major one its row count.
/// C and D share the RM_C layout.
///
/// @tparam config  LayerConfig-style type providing k_thread_block / k_warp.
/// @tparam RM_A    true if A is row-major, false if column-major.
/// @tparam RM_B    true if B is row-major, false if column-major.
/// @tparam RM_C    true if C and D are row-major, false if column-major.
/// @param stream      CUDA stream the GEMM is launched on.
/// @param M, K, N     Problem dimensions as described above.
/// @param A, B        Input matrices.
/// @param C           Source matrix handed to the epilogue.
/// @param D           Output matrix.
/// @param act         Activation forwarded to the epilogue.
/// @param transfer    Selects the ActivationTransferEpilogue (see above).
/// @param sum_source  Whether C is added before the activation; only used when
///                    `transfer` is false.
template <typename config, bool RM_A, bool RM_B, bool RM_C>
void fc_multiply(cudaStream_t stream, const int M, const int K, const int N,
                 const __half* A, const __half* B, const __half* C, __half* D,
                 Activation act = Activation::None, bool transfer = false,
                 bool sum_source = false) {
  // An empty output matrix leaves nothing to compute. Launching anyway would
  // request a grid with a zero dimension, which is a CUDA launch error and
  // would make CUTLASS_CHECK terminate the process.
  if (M <= 0 || N <= 0) {
    return;
  }

  using CutlassLayoutA =
      typename std::conditional<RM_A, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;
  using CutlassLayoutB =
      typename std::conditional<RM_B, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;
  using CutlassLayoutC =
      typename std::conditional<RM_C, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;

  // Element type of A and B, and element type of C and D, respectively.
  using MatmulTypeCompute = cutlass::half_t;
  using MatmulTypeAccumulator = cutlass::half_t;

  const int lda = RM_A ? K : M;
  const int ldb = RM_B ? N : K;
  const int ldc = RM_C ? N : M;
  const int ldd = RM_C ? N : M;

  // Reinterpret the CUDA half pointers as CUTLASS half pointers, keeping the
  // const qualification of the read-only operands.
  const auto* a_ptr = reinterpret_cast<const MatmulTypeCompute*>(A);
  const auto* b_ptr = reinterpret_cast<const MatmulTypeCompute*>(B);
  const auto* c_ptr = reinterpret_cast<const MatmulTypeAccumulator*>(C);
  auto* d_ptr = reinterpret_cast<MatmulTypeAccumulator*>(D);

  if (transfer) {
    using Gemm = OurGemm<ActivationTransferOp<MatmulTypeAccumulator>, config,
                         MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute,
                         CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>;
    // Trailing 1: a single split-K slice.
    typename Gemm::Arguments arguments{{M, N, K},
                                       {a_ptr, lda},
                                       {b_ptr, ldb},
                                       {c_ptr, ldc},
                                       {d_ptr, ldd},
                                       {act},
                                       1};

    fc_multiply_impl<Gemm>(stream, arguments);
  } else {
    using Gemm = OurGemm<ActivationOp<MatmulTypeAccumulator>, config,
                         MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute,
                         CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>;
    // Trailing 1: a single split-K slice.
    typename Gemm::Arguments arguments{{M, N, K},
                                       {a_ptr, lda},
                                       {b_ptr, ldb},
                                       {c_ptr, ldc},
                                       {d_ptr, ldd},
                                       {act, sum_source},
                                       1};

    fc_multiply_impl<Gemm>(stream, arguments);
  }
}

/// Convenience overload without a separate source matrix: computes
/// D = warp_activation(act, A @ B) by forwarding to the full fc_multiply with
/// D passed as the source as well. `transfer` and `sum_source` keep their
/// defaults (false) in the forwarded call.
template <typename config, bool RM_A, bool RM_B, bool RM_C>
void fc_multiply(cudaStream_t stream, const int M, const int K, const int N,
                 const __half* A, const __half* B, __half* D,
                 Activation act = Activation::None) {
  // The full overload needs some source pointer; the output doubles as it.
  const __half* source = D;
  fc_multiply<config, RM_A, RM_B, RM_C>(stream, M, K, N, A, B, source, D, act);
}

/// Runs one half-precision GEMM through CUTLASS's parallel split-K kernel on
/// `stream`.
///
/// Shapes: A is [M, K], B is [K, N], C and D are [M, N].
///
/// The epilogue is CUTLASS's LinearCombination with the parameters {1, 0}
/// (alpha, beta), i.e. D = 1 * (A @ B) + 0 * C, so the contents of C do not
/// contribute to the result.
///
/// Leading dimensions are derived from the shapes (densely packed matrices):
/// a row-major matrix uses its column count, a column-major one its row count.
/// C and D share the RM_C layout.
///
/// @tparam config  LayerConfig-style type providing k_thread_block / k_warp.
/// @tparam RM_A    true if A is row-major, false if column-major.
/// @tparam RM_B    true if B is row-major, false if column-major.
/// @tparam RM_C    true if C and D are row-major, false if column-major.
/// @param stream          CUDA stream the GEMM is launched on.
/// @param M, K, N         Problem dimensions as described above.
/// @param A, B            Input matrices.
/// @param C               Source matrix handed to CUTLASS.
/// @param D               Output matrix.
/// @param split_k_slices  Number of split-K slices forwarded to CUTLASS.
template <typename config, bool RM_A, bool RM_B, bool RM_C>
void fc_multiply_split_k(cudaStream_t stream, const int M, const int K,
                         const int N, const __half* A, const __half* B,
                         const __half* C, __half* D, int split_k_slices = 1) {
  // An empty output matrix leaves nothing to compute. Launching anyway would
  // request a grid with a zero dimension, which is a CUDA launch error and
  // would make CUTLASS_CHECK terminate the process.
  if (M <= 0 || N <= 0) {
    return;
  }

  using CutlassLayoutA =
      typename std::conditional<RM_A, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;
  using CutlassLayoutB =
      typename std::conditional<RM_B, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;
  using CutlassLayoutC =
      typename std::conditional<RM_C, cutlass::layout::RowMajor,
                                cutlass::layout::ColumnMajor>::type;

  // Element type of A and B, and element type of C and D, respectively.
  using MatmulTypeCompute = cutlass::half_t;
  using MatmulTypeAccumulator = cutlass::half_t;

  const int lda = RM_A ? K : M;
  const int ldb = RM_B ? N : K;
  const int ldc = RM_C ? N : M;
  const int ldd = RM_C ? N : M;

  // Reinterpret the CUDA half pointers as CUTLASS half pointers, keeping the
  // const qualification of the read-only operands.
  const auto* a_ptr = reinterpret_cast<const MatmulTypeCompute*>(A);
  const auto* b_ptr = reinterpret_cast<const MatmulTypeCompute*>(B);
  const auto* c_ptr = reinterpret_cast<const MatmulTypeAccumulator*>(C);
  auto* d_ptr = reinterpret_cast<MatmulTypeAccumulator*>(D);

  using Gemm =
      SplitKGemm<SumOp<MatmulTypeAccumulator>, config, MatmulTypeCompute,
                 CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB,
                 MatmulTypeAccumulator, CutlassLayoutC>;

  typename Gemm::Arguments arguments{
      {M, N, K},
      {a_ptr, lda},
      {b_ptr, ldb},
      {c_ptr, ldc},
      {d_ptr, ldd},
      // alpha = 1, beta = 0
      {static_cast<TypeCompute>(1.0f), static_cast<TypeCompute>(0.0f)},
      split_k_slices};

  fc_multiply_split_k_impl<Gemm>(stream, arguments);
}

/// Convenience overload without a separate source matrix: forwards to the full
/// fc_multiply_split_k with D passed as the source as well.
template <typename config, bool RM_A, bool RM_B, bool RM_C>
void fc_multiply_split_k(cudaStream_t stream, const int M, const int K,
                         const int N, const __half* A, const __half* B,
                         __half* D, int split_k_slices = 1) {
  // The full overload needs some source pointer; the output doubles as it.
  const __half* source = D;
  fc_multiply_split_k<config, RM_A, RM_B, RM_C>(stream, M, K, N, A, B, source,
                                                D, split_k_slices);
}
